# -*- coding: utf-8 -*-
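"""
Post-process trained PhyCRNet runs found under ./model/.

For every run, the script does the following:
- reload the saved hyper-parameters and rebuild the dataset;
- run process_shap on the run's log directory;
- plot per-date SHAP figures as SVG, convert them to PNG, and pass the PNGs to p2v (fps=6).

Created on Fri Apr 15 14:22:23 2022
"""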
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
from torch.autograd import Variable
import time
import pandas as pd

#import networkx as nx
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

from tensorboardX import SummaryWriter
import pickle
from tqdm import tqdm
from models import PhyCRNet,MaskedMSELoss,MaskedL1Loss,weightMSELoss
from data_utils import PrepareConcDataset
from utils import plot_conc,read_logs,count_weight
from shap_utils import plot_shap_sig,plot_shap_collect,p2v,read_svg_path,create_png_path
from shap_utils import get_shap_value,process_shap,svg2png_cairo,svg2png_aw
import shap
import cv2

#%%
# Default run configuration; the per-run loop below replaces ``args`` with the
# dict pickled for each run.
args={}
args['cuda'] = torch.cuda.is_available()
# Always bind ``device``: fall back to the CPU when CUDA is unavailable so the
# name exists on every machine.
if args['cuda']:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


# NOTE(review): best_mae / best_model_path are not referenced by the code
# visible in this file; kept for interactive use.
best_mae=100
best_model_path=0

#%%
def write_pkl(write_data,pkl_path):
    """Serialise *write_data* to the file *pkl_path* with :mod:`pickle`.

    The file is opened inside a ``with`` block so it is flushed and closed as
    soon as the dump finishes, instead of whenever the handle is collected.
    An existing file at *pkl_path* is overwritten.
    """
    with open(pkl_path,'wb') as pkl_file:
        pickle.dump(write_data,pkl_file)

def read_pkl(pkl_file_path):
    """Load and return the object pickled in *pkl_file_path*.

    The file is closed as soon as the object has been loaded.
    Only use this on files produced by this project: unpickling can execute
    arbitrary code.
    """
    with open(pkl_file_path,'rb') as pkl_file:
        return pickle.load(pkl_file)
def make_dir(dir_path):
    """Create directory *dir_path*, including missing parents, if it is absent.

    ``exist_ok=True`` replaces a separate existence check, so a directory
    created by someone else between "check" and "create" no longer raises.
    Raises FileExistsError if *dir_path* exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)

#%%
# Root folder scanned for saved training runs.
log_path = './model/'

# read_logs is expected to return a dict of indexable sequences under the keys
# "log_dirs", "model_dirs", "args_dirs", "events_dirs" and "orders".
# The loop below indexes several of them with the same position, one entry
# per run.
dirs = read_logs(log_path)

#%%
for index in range(len(dirs['orders'])):

    # Restore the hyper-parameters pickled for this run (rebinds ``args``).
    with open(dirs['args_dirs'][index],'rb') as f:
        args=pickle.load(f)
        
    train_input,train_label1,train_label1_coe,train_label2,train_label2_mask,\
    test_input,test_label1,test_label1_coe,test_label2,test_label2_mask,\
    mmin_list,dev_list,conc_name,edge_list,count,node_coord,\
    time_scaler,scaler_Q,scaler_H=PrepareConcDataset(args['time_window'],args['res'],
                                                     args['pred_day'],args['buffer'],
                                                     args['gap'],args['test_day'],
                                                     args['val_ratio'],flag='plot')
    
    # train_input has 5 dimensions; only time_len, input_channels, h and w are
    # used below.
    [batch_size,time_len,input_channels,h,w] = train_input.size()
    # Same file name every iteration: only the last run's edge_list is kept.
    write_pkl(edge_list,'井名_10_10网格索引.pkl')
    #%%
    # Disabled cell: the triple-quoted string below is a no-op. Change the
    # opening ''' to #''' to re-enable model loading and evaluation.
    '''
    
    model_path=dirs['model_dirs'][index]
    
    model = PhyCRNet(args['input_dim'],
                     args['hidden_dim'],
                     args['output_dim'],
                     args['input_kernel_size'],
                     args['input_stride'],
                     args['padding_mode'],
                     args['num_layers'], 
                     args['dropout'],
                     args['warmup_updates'],
                     args['tot_updates'],
                     args['peak_lr'],
                     args['end_lr'],
                     args['power'],
                     args['weight_decay'],
                     edge_list,
                     scaler_Q,
                     mmin_list,
                     dev_list)
    
    model=model.to(device=device)
    model.load_state_dict(torch.load(model_path))

    count_weight(model)
    all_mae=plot_conc(args,model,dirs['log_dirs'][index],device,'sig')
    #'''
    #%%
    # Disabled cell (same string-literal trick): moves the tensors to ``device``.
    '''
    if args['cuda']:
        device=torch.device('cuda')
    else:
        device=torch.device('cpu')
        
    train_input=train_input.to(device=device)
    train_label1=train_label1.to(device=device)
    train_label1_coe=train_label1_coe.to(device=device)
    train_label2=train_label2.to(device=device)
    train_label2_mask=train_label2_mask.to(device=device)
    test_input=test_input.to(device=device)
    test_label1=test_label1.to(device=device)
    test_label1_coe=test_label1_coe.to(device=device)
    test_label2=test_label2.to(device=device)
    test_label2_mask=test_label2_mask.to(device=device)
    
    #'''
    #%%
    log_dir=dirs['log_dirs'][index]
    #get_shap_value(model,train_input,test_input,200,log_dir)
    
# '''
    peak_eg_values,peak_eg_values_var=process_shap(log_dir,time_len,input_channels,h,w,count,edge_list)
    write_pkl(peak_eg_values,'对时间_流量_酸浓度的eg值.pkl')
    # The "_var" file must hold peak_eg_values_var, not a second copy of
    # peak_eg_values.
    write_pkl(peak_eg_values_var,'对时间_流量_酸浓度的eg_var值.pkl')
    #peak_eg_values=read_pkl('model/20220708-104119/peak_eg_values.pkl')
    #peak_eg_values_var=read_pkl('model/20220708-104119/peak_eg_values_var.pkl')
    
    loot_figure_path='./figures/'
    # Drops the first 8 characters of the run directory, presumably the
    # './model/' prefix (len(log_path) == 8) -- TODO confirm against read_logs.
    log_dir=dirs['log_dirs'][index][8:]
    file_type='svg'
    folder_name='test_sig'+'_'+file_type
    topk=5
    peak_dates=np.arange(0,60,1)
    
    for peak_date in peak_dates:
        
        plot_shap_sig(edge_list,node_coord,loot_figure_path,log_dir,folder_name,file_type,peak_date,topk,
                      input_channels,time_len,peak_eg_values,peak_eg_values_var)
        
        # plot_shap_collect(edge_list,loot_figure_path,log_dir,peak_date,input_channels,time_len,peak_eg_values,peak_eg_values_var)

    

    png_folder_name='test_sig_png_1'
    svg_dirs=read_svg_path(loot_figure_path+log_dir+f'/{folder_name}/')
    # NOTE(review): .replace('svg','png') rewrites every 'svg' substring of the
    # path, not only the file extension.
    png_dirs=[svg_dir.replace(f'/{folder_name}/',f'/{png_folder_name}/').replace('svg','png') for svg_dir in svg_dirs]
    
    svg2png_cairo(svg_dirs,png_dirs,edge_list,loot_figure_path,log_dir,png_folder_name,peak_dates)
    # svg2png_aw(svg_dirs,png_dirs,edge_list,loot_figure_path,log_dir,png_folder_name,peak_date)
            
    
    #%%
    # The first rendered frame fixes the video frame size. OpenCV images are
    # (rows, cols, ...), so flipping gives (width, height).
    if not png_dirs:
        raise FileNotFoundError(f'no SVG figures found under {loot_figure_path+log_dir}/{folder_name}/')
    temp_img=cv2.imread(png_dirs[0])
    if temp_img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'could not read rendered PNG {png_dirs[0]}')
    fig_size=np.flip(np.array(temp_img.shape[:2]))
    fps=6
    p2v(fig_size, loot_figure_path, log_dir, png_dirs,png_folder_name,fps=fps)
    
# '''


